{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "实验项目内容：逻辑回归算法的编程与实现  \n",
    "实验内容：基于逻辑回归算法实现digits数据集分类（注：本notebook实际使用的是iris数据集）  \n",
    "要求：  \n",
    "    （1）构造逻辑回归多目标分类器  \n",
    "    （2）给出训练集、测试集生成代码  \n",
    "    （3）给出分类器训练和预测代码  \n",
    "    （4）评估分类器模型效果（正确率、精确率、召回率、混淆矩阵）  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 233,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import datasets,linear_model\n",
    "from sklearn.metrics import precision_score\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import recall_score\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 导入iris数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the iris dataset that ships with scikit-learn.\n",
    "iris = datasets.load_iris()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 获取相应的属性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.4, 0.2],\n",
       "       [1.4, 0.2],\n",
       "       [1.3, 0.2],\n",
       "       [1.5, 0.2],\n",
       "       [1.4, 0.2]])"
      ]
     },
     "execution_count": 235,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only columns 2 and 3 of the feature matrix as model inputs\n",
    "# (petal length and petal width per the scikit-learn iris documentation).\n",
    "X = iris.data[:, 2:4]\n",
    "# Preview the first rows instead of dumping all 150 samples into the notebook.\n",
    "X[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])"
      ]
     },
     "execution_count": 236,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Integer class label (0, 1 or 2) for every sample; this is the prediction target.\n",
    "y = iris.target\n",
    "y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成训练集和测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 237,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 39% of the samples for testing; the fixed random_state makes\n",
    "# the shuffle (and therefore the split) repeatable across runs.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.39, random_state=150\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 设置网格步长"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mesh-grid step size. NOTE(review): no later cell in this notebook reads h;\n",
    "# presumably left over from a decision-boundary plot - confirm before removing.\n",
    "h = 0.02"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 创建模型对象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-vs-rest logistic regression fitted with the SAG solver.\n",
    "# SAG shuffles the training data, so random_state is pinned to make the\n",
    "# fitted coefficients (and every metric below) reproducible on re-run.\n",
    "log_reg_model = linear_model.LogisticRegression(\n",
    "    max_iter=1000,\n",
    "    multi_class=\"ovr\",\n",
    "    solver=\"sag\",\n",
    "    random_state=150,\n",
    ")\n",
    "log_reg_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练模型对象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(max_iter=1000, multi_class='ovr', solver='sag')"
      ]
     },
     "execution_count": 240,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the classifier on the training split only; fit() returns the estimator itself.\n",
    "log_reg_model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预测模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 0, 2, 2, 1, 0, 0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 2, 2, 2, 2, 1, 1,\n",
       "       0, 0, 1, 1, 0, 2, 2, 1, 2, 1, 0, 0, 1, 2, 2, 0, 0, 0, 0, 0, 1, 0,\n",
       "       2, 1, 0, 0, 0, 0, 2, 2, 0, 2, 1, 2, 0, 1, 1])"
      ]
     },
     "execution_count": 241,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted class label for every held-out test sample.\n",
    "y_pred = log_reg_model.predict(X_test)\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型评分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 242,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9491525423728814"
      ]
     },
     "execution_count": 242,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For classifiers, score() reports the mean accuracy on the given data and labels.\n",
    "srm = log_reg_model.score(X_test, y_test)\n",
    "srm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 正确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 243,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9491525423728814\n",
      "逻辑回归算法的分类正确率是:94.92%\n"
     ]
    }
   ],
   "source": [
    "# Accuracy: fraction of test samples whose predicted label equals the true label.\n",
    "acc = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(acc)\n",
    "print(\"逻辑回归算法的分类正确率是:%0.2f%%\" % (100 * acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 精确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 244,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9491525423728814\n",
      "逻辑回归算法的精确率是:94.92%\n"
     ]
    }
   ],
   "source": [
    "# Micro-averaged precision: true/false positives are pooled over all classes.\n",
    "# For single-label multiclass data this equals plain accuracy, which is why\n",
    "# the value printed here is identical to the accuracy score.\n",
    "acc2 = precision_score(y_true=y_test, y_pred=y_pred, average='micro')\n",
    "print(acc2)\n",
    "# The label must say precision; it was previously copy-pasted from the accuracy cell.\n",
    "print(\"逻辑回归算法的精确率是:%0.2f%%\" % (100 * acc2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 召回率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9491525423728814\n",
      "逻辑回归算法的召回率是:94.92%\n"
     ]
    }
   ],
   "source": [
    "# Micro-averaged recall: true positives and false negatives pooled over all classes.\n",
    "recall = recall_score(y_true=y_test, y_pred=y_pred, average='micro')\n",
    "print(recall)\n",
    "print(\"逻辑回归算法的召回率是:%0.2f%%\" % (100 * recall))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 混淆矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 246,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[23,  0,  0],\n",
       "       [ 0, 17,  3],\n",
       "       [ 0,  0, 16]], dtype=int64)"
      ]
     },
     "execution_count": 246,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confusion matrix: rows are the true classes, columns are the predicted classes.\n",
    "confusion_matrix(y_true=y_test, y_pred=y_pred)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:root] *",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
